fix: ignore padding tokens in Bart loss #7828
Conversation
I don't seem to be able to add reviewers, but I guess this would fall into the domain of @sshleifer.

Thx for the contribution! FYI we have a seq2seq finetuner with this bugfix. I worked on this at some point and thought I had fixed it. Any issue with merging this @patil-suraj @patrickvonplaten?
sgugger
left a comment
Thanks for the fix!
LGTM, but this should be documented. I've seen a few notebooks where people set the pad tokens to -100 in the labels. We should change this for T5 as well.
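For reference, this is roughly the pattern those notebooks follow (a minimal sketch of my own, not taken from this PR): replace the pad token ids in `labels` with -100, which is the default `ignore_index` of `CrossEntropyLoss`, so padding positions are skipped by the loss.

```python
import torch

# Hypothetical padded label ids for one target sequence; BART's pad_token_id is 1.
pad_token_id = 1
labels = torch.tensor([[0, 9226, 16, 10, 4819, 2, 1, 1, 1]])

# The default ignore_index of torch.nn.CrossEntropyLoss is -100, so these
# positions no longer contribute to the loss or the gradients.
labels[labels == pad_token_id] = -100
```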
```diff
 if labels is not None:
-    loss_fct = CrossEntropyLoss()
-    # TODO(SS): do we need to ignore pad tokens in labels?
+    loss_fct = CrossEntropyLoss(ignore_index=self.config.pad_token_id)
```
Assert if there are -100 values in the labels?
Wdyt @sshleifer?
Great idea. We should probably also do FSMT.

Thanks, I did not see that! With the fix in the model I was able to train Pegasus with the standard Trainer.

Good point, I remember that threw me off because the model's docstring explicitly says that -100 works.

I updated the docstring and added two assertions. Are these the assertions you were looking for, @patil-suraj?
```python
if labels is not None:
    assert -100 not in labels
    assert labels.min() > 0, f'negative labels are not supported, got {labels.min()}'
```
This should be `labels.min() >= 0`, right?
yes good catch
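For what it's worth, a quick illustration (my own, not from the PR) of why `>= 0` matters: token id 0 is a perfectly valid label for BART, since `<s>` has id 0 in its vocabulary, so `labels.min() > 0` would reject ordinary batches.

```python
import torch
from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')

# The tokenizer prepends <s>, whose id is 0, to every sequence.
labels = torch.tensor([tokenizer("a short summary").input_ids])

print(tokenizer.bos_token_id)  # 0
print(labels.min())            # tensor(0)
assert labels.min() >= 0       # passes; `labels.min() > 0` would fail here
```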
I don't think this is something we can accept in the forward pass of the BART model, as it would severely harm TPU performance. The assertion means we would be retrieving the value of the XLA tensor `labels.min() > 0` back on the CPU every time, which would cause a big performance drop.
I would advocate for putting this in the dataloader instead; the loss will crash anyway when it sees a label value that is negative and is not the ignored index.
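A plain-PyTorch illustration of the pattern being described (my own sketch, not code from the PR): the assert evaluates a 0-dim tensor as a Python bool, which requires its concrete value on the host; with XLA tensors that means flushing pending work and copying the result back from the TPU on every forward pass.

```python
import torch

# On TPU, `labels` would be an XLA tensor living on the device.
labels = torch.tensor([[0, 5872, 7, 2, 1, 1]])

check = labels.min() >= 0  # a 0-dim tensor; under torch_xla it is evaluated lazily
# `assert check` implicitly calls bool(check), which needs the concrete value
# on the host -- with XLA that means a blocking device-to-host transfer on
# every forward pass.
assert check, "negative labels are not supported"
```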
aargh, we completely ignored TPU, makes sense. Thanks Lysandre!

I am not in favor of this PR, to be honest.
Already discussed offline with @sshleifer. What are your thoughts on this, @LysandreJik @sgugger @thomwolf?
patrickvonplaten
left a comment
Don't think we should force the model to only ignore pad tokens in the loss.

I agree that it would be nice to have a uniform pattern across the model architectures, allowing the models to be used interchangeably. It seems there is some work needed to allow this:

```
/usr/local/lib/python3.6/dist-packages/transformers/modeling_bart.py in forward(self, input_ids, attention_mask, output_attentions, output_hidden_states, return_dict)
    332     attention_mask = invert_mask(attention_mask)
    333
--> 334     inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
    335     embed_pos = self.embed_positions(input_ids)
    336     x = inputs_embeds + embed_pos

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530         result = self._slow_forward(*input, **kwargs)
    531     else:
--> 532         result = self.forward(*input, **kwargs)
    533     for hook in self._forward_hooks.values():
    534         hook_result = hook(self, input, result)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/sparse.py in forward(self, input)
    112     return F.embedding(
    113         input, self.weight, self.padding_idx, self.max_norm,
--> 114         self.norm_type, self.scale_grad_by_freq, self.sparse)
    115
    116 def extra_repr(self):

/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   1482     # remove once script supports set_grad_enabled
   1483     _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 1484     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
   1485
   1486

RuntimeError: index out of range: Tried to access index -100 out of table with 50263 rows. at /pytorch/aten/src/TH/generic/THTensorEvenMoreMath.cpp:418
```
@lvwerra I think we should ignore

@sshleifer I tried to pass:

```python
from transformers import BartForConditionalGeneration, BartTokenizer
import torch

model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')

model(torch.tensor([[0, -100]]))
```

But I get the same error with:

```python
t = model(torch.tensor([[0, 1]]), decoder_input_ids=torch.tensor([[0, 0, -100]]))
```

It seems that the error comes from the line
I think to successfully implement the -100 strategy (which I have never done),

Yes, I would tend to agree with @patrickvonplaten. I think the usual philosophy of the lib is that we let the user handle this themselves and have clear and simple examples which show that you should replace pad token ids with the ignore index in the labels.

I somehow missed the notification when @patrickvonplaten asked for advice earlier, but I agree with what he said. We only handle a basic loss computation inside the model. We refused PRs adding weights for the cross-entropy recently, for the same reason @thomwolf just pointed out: anything fancier should be done by the users themselves, as we can't support every use case. For the
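For reference, a rough sketch of what handling this on the user side could look like (my own example, not something from this PR): don't pass `labels` to the forward call, and compute the masked loss on the returned logits with whatever `ignore_index` you want.

```python
import torch
from torch.nn import CrossEntropyLoss
from transformers import BartForConditionalGeneration, BartTokenizer
from transformers.modeling_bart import shift_tokens_right

model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')

src = tokenizer(["a long document to summarize"], return_tensors="pt")
labels = tokenizer(["a short summary"], return_tensors="pt")["input_ids"]
decoder_input_ids = shift_tokens_right(labels, model.config.pad_token_id)

# No `labels` argument: the model only returns logits, and the loss is
# computed outside with the padding handling the user wants.
logits = model(input_ids=src["input_ids"],
               attention_mask=src["attention_mask"],
               decoder_input_ids=decoder_input_ids)[0]

loss_fct = CrossEntropyLoss(ignore_index=model.config.pad_token_id)
loss = loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
```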
Thanks for the feedback @thomwolf & @sgugger! From a user perspective, I think it would be great if one could use a model in combination with the

If you want to go the

Hey @lvwerra, I think the main arguments against ignoring the
Ok, so if I understand correctly, the minimal example to train a Bart model given a

```python
from transformers.modeling_bart import shift_tokens_right

def convert_to_features(example_batch):
    input_encodings = tokenizer.batch_encode_plus(example_batch['text'], pad_to_max_length=True)
    target_encodings = tokenizer.batch_encode_plus(example_batch['summary'], pad_to_max_length=True)

    labels = target_encodings['input_ids']
    decoder_input_ids = shift_tokens_right(labels, model.config.pad_token_id)
    labels[labels[:, :] == 0] = -100

    encodings = {
        'input_ids': input_encodings['input_ids'],
        'attention_mask': input_encodings['attention_mask'],
        'decoder_input_ids': decoder_input_ids,
        'labels': labels,
    }

    return encodings
```

It took me quite a while reading examples and code to figure this out. Not only the thing with the padding tokens and -100, but also the difference between
You could write a forums post and link to it from bart.rst?
What does this PR do?
There is a discrepancy between the fine-tuning script (`examples/seq2seq/finetune.py`) and `BartForConditionalGeneration` (`transformers/src/modeling_bart.py`), which is also noted in the code comments. Training with the `Trainer` and `BartForConditionalGeneration` results in a model that produces garbled text (lots of repetitions and no coherence). Adding `ignore_index=self.config.pad_token_id` to the `CrossEntropyLoss` resolves the issue.

Besides a before-and-after run, I did not study the behaviour in a systematic way, since training the model requires a significant amount of time and compute. If you would like to see more testing, let me know what you think is the best way to test this thoroughly.
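To make the reported effect a bit more concrete, here is a small, self-contained illustration (my own, not part of the PR) of how unmasked padding skews the training signal:

```python
import torch
from torch.nn import CrossEntropyLoss

pad_token_id = 1    # BART's pad token id
vocab_size = 100    # toy vocabulary size, just for illustration

# One target sequence: 4 real tokens followed by 12 pad tokens, as produced
# by padding every summary to a fixed maximum length.
labels = torch.tensor([[0, 42, 16, 2] + [pad_token_id] * 12])
logits = torch.randn(1, labels.size(1), vocab_size)

# Without ignore_index, 12 of the 16 averaged terms teach the model to
# predict <pad> after the summary, which plausibly explains the degenerate
# outputs described above.
loss_with_pad = CrossEntropyLoss()(logits.view(-1, vocab_size), labels.view(-1))

# With ignore_index=pad_token_id, only the 4 real tokens contribute.
loss_masked = CrossEntropyLoss(ignore_index=pad_token_id)(
    logits.view(-1, vocab_size), labels.view(-1))

print(loss_with_pad, loss_masked)
```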